# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import threading

import tensorflow as tf

from tool import common
from train import trainModel

def lwh_train_length_thread():
    """Train the lwh-length regression model (vgg64) on a background thread.

    Runs a tf.estimator training loop for ``totalStep`` steps, evaluating
    after every ``step`` training steps and decaying the learning rate by
    0.2x at 40% and 70% of the total steps. The final model is exported as
    a SavedModel under ``modelPython``. Returns immediately; training runs
    in a separate (non-daemon) thread.
    """

    def tf_train():

        def train(argv):
            model_fn = trainModel.regression_vgg64_fn

            feature_columns = [tf.feature_column.numeric_column(
                key="x", shape=[common.dataProcess_lwh_length.xSize])]
            learning_rate = 1e-4 * 0.8
            epsilon = 1e-6
            step = 1000
            totalStep = 60 * step

            def build_model(lr):
                # Re-creating the Estimator is how the learning rate is changed
                # mid-training; weights are restored from model_dir checkpoints.
                return tf.estimator.Estimator(
                    model_fn=model_fn, model_dir=common.model_dir,
                    params={"feature_columns": feature_columns,
                            "learning_rate": lr, 'epsilon': epsilon})

            model = build_model(learning_rate)
            # =============================================================================
            stepCount = 0
            while stepCount < totalStep:
                # BUG FIX: the original used bitwise `|`, which binds tighter
                # than `==` and turned this into a chained comparison that
                # never matched, so the LR decay never fired.
                if stepCount == int(totalStep * 0.4) or stepCount == int(totalStep * 0.7):
                    learning_rate = learning_rate * 0.2
                    model = build_model(learning_rate)
                model.train(input_fn=common.dataProcess_lwh_length.input_train, steps=step)
                eval_result = model.evaluate(input_fn=common.dataProcess_lwh_length.input_test)
                stepCount = stepCount + step
                # 50000 presumably rescales the normalized std back to mm — TODO confirm
                print(str(stepCount) + ": std = {:.2f}mm".format(50000 * eval_result["std"]))

            # =============================================================================
            export_dir = "modelPython"
            model.export_saved_model(export_dir, common.dataProcess_lwh_length.serving_input_receiver_fn, as_text=False)
            # Fixed duplicated wording; now consistent with the sibling trainers.
            print('python model export done')

        tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.FATAL)
        tf.compat.v1.app.run(main=train)

    th = threading.Thread(target=tf_train)
    th.start()

def tic_train_length_thread():
    """Train the tic-length regression model (vgg48) on a background thread.

    Runs a tf.estimator training loop for ``totalStep`` steps, evaluating
    after every ``step`` training steps and decaying the learning rate by
    0.2x at 40% and 70% of the total steps. The final model is exported as
    a SavedModel under ``modelPython``. Returns immediately; training runs
    in a separate (non-daemon) thread.
    """

    def tf_train():

        def train(argv):
            model_fn = trainModel.regression_vgg48_fn

            feature_columns = [tf.feature_column.numeric_column(
                key="x", shape=[common.dataProcess_tic_length.xSize])]
            learning_rate = 1e-4 * 0.8
            epsilon = 1e-6
            step = 500
            totalStep = 20 * step

            def build_model(lr):
                # Re-creating the Estimator is how the learning rate is changed
                # mid-training; weights are restored from model_dir checkpoints.
                return tf.estimator.Estimator(
                    model_fn=model_fn, model_dir=common.model_dir,
                    params={"feature_columns": feature_columns,
                            "learning_rate": lr, 'epsilon': epsilon})

            model = build_model(learning_rate)
            # =============================================================================
            stepCount = 0
            while stepCount < totalStep:
                # BUG FIX: the original used bitwise `|`, which binds tighter
                # than `==` and turned this into a chained comparison that
                # never matched, so the LR decay never fired.
                if stepCount == int(totalStep * 0.4) or stepCount == int(totalStep * 0.7):
                    learning_rate = learning_rate * 0.2
                    model = build_model(learning_rate)
                model.train(input_fn=common.dataProcess_tic_length.input_train, steps=step)
                eval_result = model.evaluate(input_fn=common.dataProcess_tic_length.input_test)
                stepCount = stepCount + step
                # 50000 presumably rescales the normalized std back to mm — TODO confirm
                print(str(stepCount) + ": std = {:.2f}mm".format(50000 * eval_result["std"]))

            # =============================================================================
            export_dir = "modelPython"
            model.export_saved_model(export_dir, common.dataProcess_tic_length.serving_input_receiver_fn, as_text=False)
            print('python model export done')

        tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.FATAL)
        tf.compat.v1.app.run(main=train)

    th = threading.Thread(target=tf_train)
    th.start()

def tic_train_type_thread():
    """Train the 17-class tic-type classifier (vgg224) on a background thread.

    Runs a tf.estimator training loop for ``totalStep`` steps, evaluating
    after every ``step`` training steps and decaying the learning rate by
    0.2x at two fixed step counts. The final model is exported as a
    SavedModel under ``modelPython``. Returns immediately; training runs
    in a separate (non-daemon) thread.
    """

    def tf_train():

        def train(argv):
            model_fn = trainModel.classification_vgg224_fn
            feature_columns = [tf.compat.v1.feature_column.numeric_column(
                key="x", shape=[common.dataProcess_tic_type.xSize])]
            learning_rate = 0.0006
            epsilon = 1e-6
            step = 600
            totalStep = step * 15
            ySize = 17  # number of output classes

            def build_model(lr):
                # Re-creating the Estimator is how the learning rate is changed
                # mid-training; weights are restored from model_dir checkpoints.
                return tf.compat.v1.estimator.Estimator(
                    model_fn=model_fn, model_dir=common.model_dir,
                    params={"feature_columns": feature_columns,
                            "learning_rate": lr, 'epsilon': epsilon,
                            "ySize": ySize})

            model = build_model(learning_rate)
            # =============================================================================
            stepCount = 0
            while stepCount < totalStep:
                # BUG FIX: the first decay point was 2500, but stepCount only
                # takes multiples of step (600), so it was unreachable; 2400 is
                # the nearest reachable step count — TODO confirm intent.
                if stepCount == 2400 or stepCount == 6000:
                    learning_rate = 0.2 * learning_rate
                    model = build_model(learning_rate)

                # BUG FIX: Estimator.train() returns the estimator itself, not a
                # metrics dict — the original `train_result["accuracy"]` raised
                # TypeError on the first iteration and killed the thread.
                model.train(input_fn=common.dataProcess_tic_type.input_train, steps=step)
                eval_result = model.evaluate(input_fn=common.dataProcess_tic_type.input_test)
                stepCount = stepCount + step
                print(str(stepCount) + ": accuracy = {:.3f}".format(eval_result["accuracy"]))

            # =============================================================================
            export_dir = "modelPython"
            model.export_saved_model(export_dir, common.dataProcess_tic_type.serving_input_receiver_fn, as_text=False)
            print('python model export done')

        tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.FATAL)
        tf.compat.v1.app.run(main=train)

    th = threading.Thread(target=tf_train)
    th.start()


def frozenGraph_tic_type(metaName='model.ckpt-10000.meta'):
    """Freeze the tic-type graph into a single protobuf file.

    Restores the graph from the given meta file under ``common.model_dir``,
    loads the latest checkpoint weights, converts all variables to
    constants (keeping only the 'output' node and its dependencies), and
    writes the serialized frozen graph to ``modelJava/tic1.p``.

    Args:
        metaName: filename of the ``.meta`` graph file inside
            ``common.model_dir``.
    """
    save_name = 'modelJava/tic1.p'
    # BUG FIX: the original concatenated with a hard-coded '\\' separator,
    # which only works on Windows; os.path.join is portable.
    meta_path = os.path.join(common.model_dir, metaName)
    output_node_names = ['output']
    with tf.compat.v1.Session() as sess:
        # Restore the graph structure
        saver = tf.compat.v1.train.import_meta_graph(meta_path)

        # Load the most recent checkpoint weights
        saver.restore(sess, tf.train.latest_checkpoint(common.model_dir))

        # Freeze: fold variables into constants reachable from 'output'
        frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, output_node_names)

        # Save the frozen graph
        with open(save_name, 'wb') as f:
            f.write(frozen_graph_def.SerializeToString())

    print(save_name + '保存参数完成')


def frozenGraph_tic_length(metaName='model.ckpt-10000.meta'):
    """Freeze the tic-length graph into a single protobuf file.

    Restores the graph from the given meta file under ``common.model_dir``,
    loads the latest checkpoint weights, converts all variables to
    constants (keeping only the 'output' node and its dependencies), and
    writes the serialized frozen graph to ``modelJava/tic2.p``.

    Args:
        metaName: filename of the ``.meta`` graph file inside
            ``common.model_dir``.
    """
    save_name = 'modelJava/tic2.p'
    # BUG FIX: the original concatenated with a hard-coded '\\' separator,
    # which only works on Windows; os.path.join is portable.
    meta_path = os.path.join(common.model_dir, metaName)
    output_node_names = ['output']
    with tf.compat.v1.Session() as sess:
        # Restore the graph structure
        saver = tf.compat.v1.train.import_meta_graph(meta_path)

        # Load the most recent checkpoint weights
        saver.restore(sess, tf.train.latest_checkpoint(common.model_dir))

        # Freeze: fold variables into constants reachable from 'output'
        frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, output_node_names)

        # Save the frozen graph
        with open(save_name, 'wb') as f:
            f.write(frozen_graph_def.SerializeToString())
    print(save_name + '保存参数完成')
    
